#! /usr/bin/env python3
import argparse
import collections
import json
import re
import textwrap

import numpy as np
from tabulate import tabulate

from megengine.utils.profile_analyzer import (
    NonExistNum,
    ProfileAnalyzer,
    TimeFuncHelper,
)


def _tabulate_ml(tab, **kwargs):
    r"""Render ``tab`` with :func:`tabulate`, allowing cells that span several lines.

    Each cell is converted with ``str`` and split on ``"\n"``. A row whose
    cells contain newlines is expanded into several physical rows, and shorter
    cells are padded with empty strings. Rows without newlines are passed
    through unchanged. After rendering, the grid separator line that precedes
    every continuation row is removed, so the physical rows read as one
    logical row.

    The index arithmetic assumes a grid layout with a header. That layout is
    a top border, a header line, a header rule, and then one separator
    between body rows, as produced by ``tablefmt="fancy_grid"`` with
    ``headers=...`` (how this file calls it).

    :param tab: iterable of rows (each an iterable of cells).
    :param kwargs: forwarded to :func:`tabulate`.
    :return: the rendered table without a trailing newline.
    """
    physical_rows = []
    starts_logical_row = []
    for row in tab:
        cells = [str(cell).split("\n") for cell in row]
        height = max(len(cell_lines) for cell_lines in cells)
        if height <= 1:
            # Single-line row: keep original cell objects so tabulate can
            # still align numbers.
            physical_rows.append(row)
            starts_logical_row.append(True)
            continue
        padded = [cell_lines + [""] * (height - len(cell_lines)) for cell_lines in cells]
        physical_rows.extend(zip(*padded))
        starts_logical_row.extend([True] + [False] * (height - 1))

    assert len(starts_logical_row) == len(physical_rows)
    rendered = [line + "\n" for line in tabulate(physical_rows, **kwargs).split("\n")]
    for pos, is_start in enumerate(starts_logical_row):
        if is_start:
            continue
        # Physical row ``pos`` sits at line 2*pos+3; drop the separator above it.
        rendered[2 * pos + 2] = ""
    return "".join(rendered)[:-1]


def _tabulate_confluence(tab, **kwargs):
    r"""Render ``tab`` as an ``orgtbl`` table for pasting into Confluence.

    Any caller-supplied ``tablefmt`` is ignored; ``"orgtbl"`` is always used.
    The second output line (the header rule when ``headers`` is given) has
    its ``+`` joints turned into ``|`` column separators.

    :param tab: iterable of rows.
    :param kwargs: forwarded to :func:`tabulate` (minus ``tablefmt``).
    :return: the rendered table.
    """
    options = {key: val for key, val in kwargs.items() if key != "tablefmt"}
    out_lines = tabulate(tab, tablefmt="orgtbl", **options).split("\n")
    out_lines[1] = "|".join(out_lines[1].split("+"))
    return "\n".join(out_lines)


def _build_arg_parser():
    """Create the command-line parser used by :func:`main`."""
    parser = argparse.ArgumentParser(
        description="analyze analyzer result",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("dump")
    parser.add_argument(
        "-t",
        "--top",
        type=int,
        default=3,
        help="number of most time-consuming operators to print",
    )
    parser.add_argument(
        "--type", action="append", help="filter oprs in the top list by type"
    )
    parser.add_argument(
        "--aggregate-by",
        default=None,
        choices=["type"],
        help="aggregate profiling result by",
    )
    parser.add_argument(
        "--opr-name", help="filter oprs in the top list by regex of name"
    )
    parser.add_argument(
        "--input-dtype", type=str, help="filter oprs in the top list by input dtype"
    )
    parser.add_argument(
        "--top-end-key",
        default="end",
        choices=["end", "kern"],
        help="how time in top is calculated; end corresponds "
        "to total device time, and kern corresponds to only "
        "wait time",
    )
    parser.add_argument(
        "--aggregate",
        default=None,
        help="aggregate operations",
        choices=["max", "min", "sum", "mean"],
    )
    parser.add_argument(
        "--order-by",
        default="time",
        help="sort result according to given column; the param can be "
        "<col_name> or +<col_name>, meaning sorting in descending or "
        "ascending order respectively",
    )
    parser.add_argument(
        "--copy-time", action="store_true", help="show copy time related result"
    )
    parser.add_argument(
        "--min-time",
        type=float,
        default=float("-inf"),
        help="minimal time of a result to be printed",
    )
    parser.add_argument(
        "--max-time",
        type=float,
        default=float("inf"),
        help="maximal time of a result to be printed",
    )
    parser.add_argument(
        "--show-host", action="store_true", help="show host profiling info"
    )
    parser.add_argument(
        "--dump-only-opr",
        action="store_true",
        help="only dump operator info as plaintext; useful "
        "for diff between two filtered profile results",
    )
    parser.add_argument(
        "--confluence",
        "--wiki",
        action="store_true",
        help="output confluence-markdown-compatible table",
    )
    parser.add_argument(
        "--print-only",
        choices={"summary", "device", "host"},
        help="print only chosen info",
    )
    return parser


def _make_opr_filter(args):
    """Build the ``opr_filter(o, a, b)`` predicate passed to ``ProfileAnalyzer``.

    The predicate is true when the operator passes every filter requested via
    ``--type``, ``--opr-name`` and ``--input-dtype``. It is always true when
    none of those options is given. As used below, ``o`` is indexed by
    ``"type"``/``"name"`` and ``a`` is iterated as entries carrying
    ``["mem_plan"]["layout"]["dtype"]``; ``b`` is not inspected here.
    """
    opr_filters = []
    if args.type:
        opr_filters.append(lambda o, a, b: o["type"] in args.type)
    if args.opr_name:
        # Compile once; bound as a default so the lambda does not recompile.
        opr_filters.append(
            lambda o, a, b, r=re.compile(args.opr_name): r.match(o["name"])
        )
    if args.input_dtype:
        opr_filters.append(
            lambda o, a, b: any(
                i["mem_plan"]["layout"]["dtype"] == args.input_dtype for i in a
            )
        )
    if not opr_filters:
        return lambda o, a, b: True  # pylint: disable=unused-argument
    return lambda o, a, b: all(f(o, a, b) for f in opr_filters)


def _format_shapes(shapes, layouts=None, sep="\n"):
    """Render shapes as ``{d0,d1,...}`` joined by ``sep``, each optionally followed by its layout."""
    if isinstance(shapes, NonExistNum) or shapes is None:
        return repr(shapes)
    if layouts is None:
        layouts = [None] * len(shapes)

    comp = []
    for shape, layout in zip(shapes, layouts):
        text = "{" + ",".join(map(str, shape)) + "}"
        if layout:
            text += "\n -[" + ",".join(map(str, layout)) + "]"
        comp.append(text)
    return sep.join(comp)


_UNIT_PREFIXES = ["", "K", "M", "G", "T", "P"]


def _fix_num_and_find_unit(x, base):
    """Scale ``x`` down by powers of ``base``; return ``(scaled, prefix)``.

    Missing values (``NonExistNum``) and non-finite floats, including numpy
    floating scalars, are returned unchanged with an empty prefix. Values
    beyond the largest prefix ("P") stay at "P" with a larger mantissa
    instead of raising.
    """
    if isinstance(x, NonExistNum) or (
        isinstance(x, (float, np.floating)) and not np.isfinite(x)
    ):
        return x, ""
    idx = 0
    while x >= base and idx < len(_UNIT_PREFIXES) - 1:
        x /= base
        idx += 1
    return x, _UNIT_PREFIXES[idx]


def _get_number_with_unit(num, unit, base, sep="\n"):
    """Format ``num`` as ``"<value:.2f><sep><prefix><unit>"``.

    ``unit`` may be a two-item list: item 0 is used without a prefix, and
    item 1 is used with one.
    """
    num, unit_prefix = _fix_num_and_find_unit(num, base)
    if isinstance(unit, list):
        unit = unit[int(unit_prefix != "")]
    return ("{:.2f}" + sep + "{}{}").format(num, unit_prefix, unit)


def _max_existing(values):
    """Return the largest value that is not a ``NonExistNum``, or 0 if there is none."""
    real = [v for v in values if not isinstance(v, NonExistNum)]
    return max(real) if real else 0


def _bar(value, max_value, length):
    """Return a ``#`` bar of ``value / max_value`` scaled to ``length`` chars, or ``""`` if undefined."""
    if isinstance(value, NonExistNum) or not max_value:
        return ""
    return "#" * int(value / max_value * length)


def _render_details_confluence(records, tot_time, bar_length=15):
    """Render ``records`` as a Confluence-compatible table ("" if there are none)."""
    if not records:
        # Headers depend on the records' info keys; nothing to render.
        return ""

    # Skip missing values so one NonExistNum cannot break the bar scaling.
    max_time = _max_existing(r.time for r in records)
    max_bandwidth = _max_existing(r.bandwidth for r in records)
    max_flops = _max_existing(r.flops for r in records)

    rows = []
    cum_time = 0
    for idx, record in enumerate(records):
        cum_time += record.time
        opr_info = [("opr " + k, v) for k, v in record.info.items()]
        rows.append(
            collections.OrderedDict(
                [
                    ("#", idx),
                    ("time", "{:.3}".format(record.time)),
                    ("ratio", "{:.1f}%".format(record.time / tot_time * 100)),
                    ("time bar", _bar(record.time, max_time, bar_length)),
                    ("cum-time", cum_time),
                    ("cum-time ratio", cum_time / tot_time),
                ]
                + opr_info
                + [
                    (
                        "computation (MFLO)",
                        "{:.1f}".format(record.computation / 1000 ** 2),
                    ),
                    ("MFLOPS", "{:.1f}".format(record.flops / 1000 ** 2)),
                    ("MFLOPS-bar", _bar(record.flops, max_flops, bar_length)),
                    ("memory (MB)", "{:.1f}".format(record.memory / 1024 ** 2)),
                    (
                        "bandwidth (MiB/s)",
                        "{:.1f}".format(record.bandwidth / 1024 ** 2),
                    ),
                    (
                        "bandwidth bar",
                        _bar(record.bandwidth, max_bandwidth, bar_length),
                    ),
                    (
                        "in_shapes",
                        _format_shapes(record.in_shapes, record.in_layouts, sep=", "),
                    ),
                    ("out_shapes", _format_shapes(record.out_shapes, sep=", ")),
                ]
            )
        )
    headers = list(rows[0].keys())
    tab = [[row[h] for h in headers] for row in rows]
    return _tabulate_confluence(tab, headers=headers)


def _render_details_grid(records, prof_type, tot_time):
    """Render ``records`` as a multi-line ``fancy_grid`` table."""
    tab = []
    cum_time = 0
    for idx, record in enumerate(records):
        cum_time += record.time
        tab.append(
            (
                "#{}\n{:.3}\n{:.1f}%".format(
                    idx, record.time, record.time / tot_time * 100
                ),
                "{:.3}\n{:.1f}%".format(cum_time, cum_time / tot_time * 100),
                "\n".join(
                    "\n-  ".join(textwrap.wrap(str(i), width=30))
                    for i in record.info.values()
                ),
                _get_number_with_unit(record.computation, "FLO", 1000),
                _get_number_with_unit(record.flops, "FLOPS", 1000),
                _get_number_with_unit(record.memory, ["byte", "iB"], 1024),
                _get_number_with_unit(record.bandwidth, ["byte/s", "iB/s"], 1024),
                _format_shapes(record.in_shapes, record.in_layouts),
                _format_shapes(record.out_shapes),
            )
        )
    return _tabulate_ml(
        tab,
        headers=[
            "{} self time".format(prof_type),
            "cumulative",
            "operator info",
            "computation",
            "FLOPS",
            "memory",
            "bandwidth",
            "in_shapes",
            "out_shapes",
        ],
        tablefmt="fancy_grid",
    )


def main(passed_args=None):  # pylint: disable=too-many-statements
    r"""Analyse a profile dump produced with :mod:`~.utils.profile_analyzer`.

    Prints a summary table (total device/host time, optionally copy time).
    It is followed by the top operators by device time and, with
    ``--show-host``, by host time. Run this file with ``--help`` to get more
    usage.

    :param passed_args: argument list to parse instead of ``sys.argv[1:]``
        (forwarded to :meth:`argparse.ArgumentParser.parse_args`).
    :raises AssertionError: if the total device time in the dump is 0.
    """
    args = _build_arg_parser().parse_args(passed_args)
    opr_filter = _make_opr_filter(args)

    with open(args.dump) as fin:
        dump = json.load(fin)
    # Heuristic: any mention of "opencl" marks an OpenCL profile; only used
    # to choose a more helpful hint when device time is zero.
    is_profile_from_ocl = "opencl" in str(dump)

    analyzer = ProfileAnalyzer(dump, opr_filter)
    analyzer_tot = ProfileAnalyzer(dump, lambda _, __, ___: True)

    def summary():
        """Build the summary table; return ``(table_str, tot_dev_time, tot_host_time)``."""
        device_end_func = TimeFuncHelper.eval_time_func("device", "end", np.max)
        device_kern_func = TimeFuncHelper.eval_time_func("device", "kern", np.max)
        host_end_func = TimeFuncHelper.eval_time_func("host", "end", np.max)

        def get_tot_time(func):
            # Totals are taken over the unfiltered analyzer.
            rec = analyzer_tot.select(func, aggregate=np.sum)
            if not rec:
                return "N/A"
            return rec[0].time

        def is_copy(opr):
            return opr.opr_info["type"] == "Copy"

        def copy_stat(func):
            # "tot=<sum> avg=<mean>" of ``func`` over Copy oprs.
            tot = analyzer.select(func, is_copy, aggregate=np.sum)[0]
            avg = analyzer.select(func, is_copy, aggregate=np.mean)[0]
            return "tot={:.4f} avg={:.4f}".format(tot.time, avg.time)

        tab = []
        tot_dev_time = get_tot_time(device_end_func)
        tot_host_time = get_tot_time(host_end_func)
        tab.append(("total device time", tot_dev_time))
        if tot_dev_time == 0:
            msg = (
                "please call mgb::CompNode::enable_opencl_profile(true) before profile at c/c++ code"
                if is_profile_from_ocl
                else "please raise an issue for Engine"
            )
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            raise AssertionError(
                "total device time should not be 0, {}".format(msg)
            )
        tab.append(("total host time", tot_host_time))
        if args.copy_time:
            tab.append(("copy time", copy_stat(device_end_func)))
            tab.append(("copy wait time", copy_stat(device_kern_func)))

        if args.confluence:
            tab_str = _tabulate_confluence(tab, headers=["name", "value"])
        else:
            tab_str = tabulate(tab)

        return tab_str, tot_dev_time, tot_host_time

    def prof_details(prof_type, tot_time):
        """Render the top-k operator table for ``prof_type`` ("device"/"host")."""

        def func(
            opr,
            *,
            f0=TimeFuncHelper.eval_time_func(prof_type, args.top_end_key, np.max)
        ):
            # Drop operators whose time is outside [--min-time, --max-time].
            t = f0(opr)
            if t is not None and (t < args.min_time or t > args.max_time):
                return None
            return t

        records = analyzer.select(
            func,
            aggregate=args.aggregate,
            aggregate_by=args.aggregate_by,
            top_k=args.top,
            sort_by=args.order_by,
        )

        if args.dump_only_opr:
            # str() each value: info values are not guaranteed to be strings.
            return "\n".join(" ".join(map(str, r.info.values())) for r in records)

        if args.confluence:
            return _render_details_confluence(records, tot_time)
        return _render_details_grid(records, prof_type, tot_time)

    summary_tab, tot_dev_time, tot_host_time = summary()
    if args.print_only:
        print(
            {
                "summary": lambda: summary_tab,
                "device": lambda: prof_details("device", tot_dev_time),
                "host": lambda: prof_details("host", tot_host_time),
            }[args.print_only]()
        )
    else:
        print(summary_tab)
        print()
        print(prof_details("device", tot_dev_time))
        if args.show_host:
            print()
            print(prof_details("host", tot_host_time))


if __name__ == "__main__":
    # Script entry point: arguments come from the command line.
    main(passed_args=None)
